import argparse
from typing import Optional

import orjsonl
from rich.box import ROUNDED
from rich.console import Console
from rich.table import Table
from sklearn.metrics import f1_score

from tqdm import tqdm

from utils import string_to_event_object
from schema_definition import acled_actor_field_names
from log_utils import get_logger

# Module-level logger, created through the project's ``log_utils.get_logger`` helper.
logger = get_logger(__name__)

# Shared rich Console; every table and rule emitted by this script goes through it.
console = Console()  # rich


def is_actor_field(arg_type: str) -> bool:
    """Return True if ``arg_type`` contains any entry of ``acled_actor_field_names``.

    The test is a containment check (``name in arg_type``) against each known
    actor field name, so a match anywhere inside ``arg_type`` counts.
    """
    for field_name in acled_actor_field_names:
        if field_name in arg_type:
            return True
    return False


def add_f1_to_argument_stats(argument_stats):
    """Attach an ``"f1"`` entry to every stats dict in ``argument_stats``.

    Each value of ``argument_stats`` must provide ``"tp"``, ``"fp"`` and
    ``"fn"`` counts. Precision and recall fall back to 0 when their
    denominator is 0, and F1 falls back to 0 when precision + recall is 0.

    The stats dicts are modified in place; the same mapping is returned.
    """
    for counts in argument_stats.values():
        true_pos, false_pos, false_neg = counts["tp"], counts["fp"], counts["fn"]
        predicted_total = true_pos + false_pos
        gold_total = true_pos + false_neg

        precision = true_pos / predicted_total if predicted_total > 0 else 0
        recall = true_pos / gold_total if gold_total > 0 else 0

        pr_sum = precision + recall
        counts["f1"] = 2 * (precision * recall) / pr_sum if pr_sum > 0 else 0

    return argument_stats


def argument_stats_to_table(argument_stats: dict, title: str):
    """Build a rich ``Table`` with per-type and overall precision/recall/F1.

    Args:
        argument_stats: Mapping from a row label (e.g. an argument type) to a
            dict holding ``"tp"``, ``"fp"``, ``"fn"`` and a precomputed
            ``"f1"`` value.
        title: Title shown above the table.

    Returns:
        The populated table: one row per entry of ``argument_stats``, an empty
        dimmed spacer row, and a final "Total" row computed from the summed
        counts.

    Side effects:
        Each stats dict gains ``"precision"`` and ``"recall"`` entries.
    """

    def safe_ratio(numerator, denominator):
        # A ratio with an empty denominator is reported as 0.
        return numerator / denominator if denominator > 0 else 0

    def as_percent(value):
        return f"{value * 100:.1f}%"

    table = Table(
        title=title,
        box=ROUNDED,
        title_style="on blue",  # blue background behind the title
    )
    for header, colour, alignment in (
        ("Gold Count", "blue", "right"),
        ("F1 Score", "magenta", "right"),
        ("Precision", "green", "right"),
        ("Recall", "yellow", "right"),
    ):
        if not table.columns:
            # The label column comes first and keeps the default justification.
            table.add_column("Argument Type", style="cyan")
        table.add_column(header, style=colour, justify=alignment)

    # Per-type rows; the running totals feed the summary row below.
    sum_tp = sum_fp = sum_fn = sum_gold = 0
    for label, counts in argument_stats.items():
        row_f1 = counts["f1"]
        tp, fp, fn = counts["tp"], counts["fp"], counts["fn"]
        row_precision = safe_ratio(tp, tp + fp)
        row_recall = safe_ratio(tp, tp + fn)

        table.add_row(
            label,
            f"{tp + fn}",
            as_percent(row_f1),
            as_percent(row_precision),
            as_percent(row_recall),
        )
        counts["precision"] = row_precision
        counts["recall"] = row_recall

        sum_tp += tp
        sum_fp += fp
        sum_fn += fn
        sum_gold += tp + fn

    # Overall scores are computed from the pooled counts.
    overall_precision = safe_ratio(sum_tp, sum_tp + sum_fp)
    overall_recall = safe_ratio(sum_tp, sum_tp + sum_fn)
    pr_sum = overall_precision + overall_recall
    if pr_sum > 0:
        overall_f1 = 2 * (overall_precision * overall_recall) / pr_sum
    else:
        overall_f1 = 0

    # Empty spacer row separating the per-type rows from the total.
    table.add_row(None, None, None, None, None, style="dim")

    table.add_row(
        "Total",
        f"{sum_gold}",
        as_percent(overall_f1),
        as_percent(overall_precision),
        as_percent(overall_recall),
    )

    return table


def _new_counts() -> dict:
    """Return a fresh tp/fp/fn counter dict."""
    return {"tp": 0, "fp": 0, "fn": 0}


def _tally_set_overlap(stats: dict, gold: set, predicted: set) -> None:
    """Add the tp/fp/fn counts of ``predicted`` versus ``gold`` to ``stats`` in place."""
    stats["tp"] += len(gold & predicted)
    stats["fp"] += len(predicted - gold)
    stats["fn"] += len(gold - predicted)


def _print_stats_table(stats: dict, title: str) -> None:
    """Render ``stats`` with ``argument_stats_to_table`` and print it to the console."""
    console.print(argument_stats_to_table(stats, title=title))


def main(
    input_file: str, seen_entities: set, include_event_type: bool, language: Optional[str]
):
    """Evaluate the predictions in ``input_file`` and print metric tables.

    Every row of the JSONL file must provide ``"output"`` (gold event string)
    and ``"prediction"`` (model output string); ``"language"`` is read only
    when ``language`` is given.

    Args:
        input_file: Path to the JSONL file with gold labels and predictions.
        seen_entities: Entities considered "seen"; used to split entity
            metrics into seen and unseen.
        include_event_type: Passed to ``split_to_arguments``. When True, the
            full argument tables (all / non-entity / entity) are printed.
            When False, the global metrics, location arguments and entity
            linking tables are printed instead.
        language: If set, only rows whose ``"language"`` equals this value are
            evaluated.
    """
    data = orjsonl.load(input_file)

    title_prefix = "w/ event types" if include_event_type else "w/o event types"

    # Rows that pass the language filter. This (not len(data)) is the
    # denominator for the global rates, otherwise filtered-out rows would be
    # counted as evaluated.
    num_rows = 0
    syntax_errors = 0
    correct_predictions_event_type = 0
    gold_event_types = []
    predicted_event_types = []

    # Maps an argument type (e.g. Attack.Location.country) to its tp/fp/fn.
    argument_stats = {}
    seen_actor_stats = _new_counts()
    unseen_actor_stats = _new_counts()
    actor_linking_stats = _new_counts()

    for row in tqdm(data):
        if language and row["language"] != language:
            continue
        num_rows += 1

        gold_event_object = string_to_event_object(row["output"])
        gold_event_type = gold_event_object.get_event_type()
        gold_arguments = set(
            gold_event_object.split_to_arguments(include_event_type=include_event_type)
        )

        prediction = row["prediction"]
        try:
            predicted_event_object = string_to_event_object(prediction)
        except Exception as e:
            logger.warning(
                "Error while converting model prediction to Pydantic object: %s", str(e)
            )
            syntax_errors += 1
            # An unparsable prediction misses every gold argument.
            for arg in gold_arguments:
                # Split on the first "=" only: values may themselves contain
                # "=", exactly as handled in the main scoring loop below.
                arg_type, _ = arg.split("=", 1)
                argument_stats.setdefault(arg_type, _new_counts())["fn"] += 1
            continue

        if not predicted_event_object:
            # NOTE(review): unlike the syntax-error path above, the gold
            # arguments of this row are not counted as false negatives.
            # Confirm whether that is intended.
            logger.warning("Model prediction is empty: %s", prediction)
            continue

        gold_entities = set(gold_event_object.get_entities())
        predicted_entities = set(predicted_event_object.get_entities())

        # All entities.
        _tally_set_overlap(actor_linking_stats, gold_entities, predicted_entities)

        # Entities that are in ``seen_entities``.
        _tally_set_overlap(
            seen_actor_stats,
            gold_entities.intersection(seen_entities),
            predicted_entities.intersection(seen_entities),
        )

        # Entities that are not in ``seen_entities``.
        _tally_set_overlap(
            unseen_actor_stats,
            gold_entities - seen_entities,
            predicted_entities - seen_entities,
        )

        gold_event_types.append(gold_event_type)
        predicted_event_type = predicted_event_object.get_event_type()
        predicted_event_types.append(predicted_event_type)

        if gold_event_type == predicted_event_type:
            correct_predictions_event_type += 1

        predicted_arguments = set(
            predicted_event_object.split_to_arguments(
                include_event_type=include_event_type
            )
        )

        # Per-argument-type tp/fp/fn based on exact "type=value" matches.
        for arg in gold_arguments | predicted_arguments:
            arg_type, _ = arg.split("=", 1)
            counts = argument_stats.setdefault(arg_type, _new_counts())

            in_gold = arg in gold_arguments
            in_predicted = arg in predicted_arguments
            if in_gold and in_predicted:
                counts["tp"] += 1
            elif in_predicted:
                counts["fp"] += 1
            elif in_gold:
                counts["fn"] += 1

    if num_rows == 0:
        logger.warning("No rows to evaluate in %s", input_file)

    # Argument metrics, ordered by ascending gold count.
    argument_stats = add_f1_to_argument_stats(argument_stats)
    sorted_arguments_stats = dict(
        sorted(argument_stats.items(), key=lambda item: item[1]["tp"] + item[1]["fn"])
    )

    if include_event_type:
        _print_stats_table(sorted_arguments_stats, f"Arguments ({title_prefix})")

        # Arguments excluding entity (actor) fields.
        non_actor_args = {
            k: v for k, v in sorted_arguments_stats.items() if not is_actor_field(k)
        }
        _print_stats_table(non_actor_args, f"Non-Entity Arguments ({title_prefix})")

        # Entity (actor) arguments only.
        actor_args = {
            k: v for k, v in sorted_arguments_stats.items() if is_actor_field(k)
        }
        _print_stats_table(actor_args, f"Entity Arguments ({title_prefix})")
        return

    # Global metrics. Guard the divisions so an empty (or fully filtered)
    # input reports zeros instead of raising ZeroDivisionError.
    syntax_error_rate = syntax_errors / num_rows if num_rows else 0.0
    em_event_type = correct_predictions_event_type / num_rows if num_rows else 0.0
    if gold_event_types:
        f1_event_type = f1_score(
            gold_event_types, predicted_event_types, average="micro"
        )
    else:
        # No row produced a usable prediction; nothing to score.
        f1_event_type = 0.0

    table = Table(show_header=False, box=ROUNDED, title="Global Metrics")
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="magenta", justify="right")
    table.add_row("Syntax accuracy", f"{(1 - syntax_error_rate) * 100:.1f} %")
    table.add_row("EM for event type", f"{em_event_type * 100:.1f} %")
    table.add_row("F1 for event type", f"{f1_event_type * 100:.1f} %")

    console.print("\n")
    console.print(table)

    # Location metrics.
    location_args = {
        k: v for k, v in sorted_arguments_stats.items() if "location." in k
    }
    _print_stats_table(location_args, f"Location Arguments ({title_prefix})")

    # Entity linking metrics: all entities, then seen and unseen separately.
    for label, stats in (
        ("Entity Linking", actor_linking_stats),
        ("Seen Entities", seen_actor_stats),
        ("Unseen Entities", unseen_actor_stats),
    ):
        _print_stats_table(
            add_f1_to_argument_stats({label: stats}),
            f"{label} ({title_prefix})",
        )


if __name__ == "__main__":
    # Command-line interface: predictions file, training set and optional
    # language filter.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--input_file",
        required=True,
        help="Path to the input JSONL file",
    )
    cli.add_argument(
        "--train_set",
        required=True,
        help="Path to the train set JSONL file. Used for extracting 'unseen' entities.",
    )
    cli.add_argument(
        "--language",
        type=str,
        default=None,
        help="Language code to evaluate. If None, will evaluate all languages.",
    )
    args = cli.parse_args()

    # Every entity mentioned in a training-set gold label counts as "seen".
    # NOTE(review): the gold label is read from "gold_label" here, while
    # main() reads it from "output" -- confirm the two files really use
    # different keys.
    train_rows = orjsonl.load(args.train_set)
    seen_entities = set()
    logger.info("Extracting seen entities from the training set...")
    for train_row in train_rows:
        gold_event = string_to_event_object(train_row["gold_label"])
        seen_entities.update(gold_event.get_entities())

    logger.info(f"Found {len(seen_entities):,} seen entities in the training set.")

    # Run the evaluation twice: with event types first, then without.
    for with_event_type in (True, False):
        main(
            args.input_file,
            seen_entities,
            include_event_type=with_event_type,
            language=args.language,
        )
        console.rule()
